# %matplotlib inline

import numpy as np
import scipy.sparse as sparse
import scipy.stats as stats
import matplotlib
import matplotlib.pyplot as plt
import numpy.linalg as la
import random
import time
import sklearn.metrics.pairwise as skl
import scipy.io as io

# Module-level plotting defaults: a keyword bundle for line styling and the
# font settings that are pushed into matplotlib's global rc configuration.
kwargs = dict(linewidth=2.5)
font = dict(weight="normal", size=16)
matplotlib.rc("font", **font)

def gen_gaussian(m, n, mu, sigma):
    """Draw an m-by-n matrix of independent N(mu, sigma^2) samples.

    Args:
        m: Number of rows.
        n: Number of columns.
        mu: Mean of every entry.
        sigma: Standard deviation of every entry.

    Returns:
        An ``(m, n)`` ndarray sampled from NumPy's global random state.
    """
    shape = (m, n)
    return np.random.normal(loc=mu, scale=sigma, size=shape)

def gen_countsketch(r, c):
    """Generate a random r-by-c CountSketch-style matrix in CSR format.

    Every column holds exactly one nonzero entry: a sign drawn uniformly from
    {-1, +1}, placed in a row drawn uniformly (with replacement) from
    ``range(r)``.

    Args:
        r: Number of rows (the number of hash buckets).
        c: Number of columns (the number of coordinates being hashed).

    Returns:
        A ``scipy.sparse.csr_matrix`` of shape ``(r, c)`` with dtype float32.
    """
    rows = np.random.choice(r, c, replace=True)
    signs = (2 * np.random.randint(2, size=c) - 1).astype(np.float32)
    # Build straight from (data, (row, col)) triplets instead of assigning
    # into a DOK matrix one element at a time in a Python loop.
    return sparse.csr_matrix((signs, (rows, np.arange(c))), shape=(r, c),
                             dtype=np.float32)

def alternating_minimization(left, right, update_left, update_right,
                             num_updates):
    """Run alternating updates of two factors and record every iterate.

    Each round first recomputes the left factor from the current right
    factor, then recomputes the right factor from the new left factor.

    Args:
        left: Initial left factor.
        right: Initial right factor.
        update_left: Callable mapping a right factor to a new left factor.
        update_right: Callable mapping a left factor to a new right factor.
        num_updates: Number of rounds to perform.

    Returns:
        A list of ``(left, right)`` tuples of length ``num_updates + 1``,
        starting with the initial pair.
    """
    def _rounds():
        cur_left, cur_right = left, right
        yield cur_left, cur_right
        for _ in range(num_updates):
            cur_left = update_left(cur_right)
            cur_right = update_right(cur_left)
            yield cur_left, cur_right

    return list(_rounds())

def update_right(colweights, A, U, lmda, t):
    """Sketched ridge update of the right factor V with U held fixed.

    For each column j, solves the least-squares problem
        min_v ||S D_j (U v - A[:, j])||^2 + lmda * ||v||^2
    where D_j = colweights[j] and S is a t-by-n matrix from
    ``gen_countsketch``, drawn once per call and shared by all columns.

    Args:
        colweights: Sequence indexed by column, each an n-by-n weight matrix.
        A: The n-by-d target matrix.
        U: The n-by-k left factor.
        lmda: Ridge regularization strength.
        t: Number of rows of the sketch.

    Returns:
        The k-by-d right factor as an ndarray.
    """
    n, d = A.shape
    k = U.shape[1]
    V = np.ones((k, d))
    # Appending sqrt(lmda) * I to the design matrix (and zeros to the target)
    # turns the ridge penalty into extra least-squares rows.
    ridge_rows = (lmda ** 0.5) * np.identity(k)
    ridge_target = np.zeros((k, 1))
    sketch = gen_countsketch(t, n)

    for j in range(d):
        weights = colweights[j]
        design = np.vstack((sketch @ (weights @ U), ridge_rows))
        target = np.vstack((sketch @ (weights @ A[:, [j]]), ridge_target))
        solution = la.lstsq(design, target, rcond=None)[0]
        V[:, j] = solution.T

    return V

def update_left(rowweights, A, V, lmda, t):
    """Sketched ridge update of the left factor U with V held fixed.

    For each row i, solves the least-squares problem
        min_u ||(u V - A[i, :]) D_i S||^2 + lmda * ||u||^2
    where D_i = rowweights[i] and S is a d-by-t sketch applied from the
    right, drawn once per call and shared by all rows.

    Args:
        rowweights: Sequence indexed by row, each a d-by-d weight matrix.
        A: The n-by-d target matrix.
        V: The k-by-d right factor.
        lmda: Ridge regularization strength.
        t: Sketch dimension (number of columns kept after sketching).

    Returns:
        The n-by-k left factor as an ndarray.
    """
    n, d = A.shape
    k, _ = V.shape
    U = np.full((n, k), 1.00)
    # Appending sqrt(lmda) * I to the design matrix (and zeros to the target)
    # turns the ridge penalty into extra least-squares rows.
    reg = np.identity(k) * (lmda ** 0.5)
    zeros = np.zeros((k, 1))
    # The sketch multiplies from the right, so it must be the transpose of a
    # t-by-d CountSketch: every one of the d coordinates is hashed into one of
    # t buckets. gen_countsketch(d, t) would instead put one nonzero in each
    # of its t columns, which merely samples t signed coordinates and drops
    # the rest, unlike the left-side sketch used in update_right.
    S = gen_countsketch(t, d).T

    for i in range(n):
        VDS = (V @ rowweights[i]) @ S
        ADS = (A[[i], :] @ rowweights[i]) @ S

        M = np.block([[VDS.T], [reg]])
        b = np.block([[ADS.T], [zeros]])

        U[i, :] = la.lstsq(M, b, rcond=None)[0].T

    return U


def nosketch_update_right(colweights, A, U, lmda, t):
    """Exact (unsketched) ridge update of the right factor V with U fixed.

    For each column j, solves
        min_v ||D_j (U v - A[:, j])||^2 + lmda * ||v||^2
    with D_j = colweights[j].

    Args:
        colweights: Sequence indexed by column, each an n-by-n weight matrix.
        A: The n-by-d target matrix.
        U: The n-by-k left factor.
        lmda: Ridge regularization strength.
        t: Unused; accepted so the signature matches ``update_right``.

    Returns:
        The k-by-d right factor as an ndarray.
    """
    n, d = A.shape
    k = U.shape[1]
    V = np.ones((k, d))
    # Appending sqrt(lmda) * I to the design matrix (and zeros to the target)
    # turns the ridge penalty into extra least-squares rows.
    ridge_rows = (lmda ** 0.5) * np.identity(k)
    ridge_target = np.zeros((k, 1))

    for j in range(d):
        weights = colweights[j]
        design = np.vstack((weights @ U, ridge_rows))
        target = np.vstack((weights @ A[:, [j]], ridge_target))
        solution = la.lstsq(design, target, rcond=None)[0]
        V[:, j] = solution.T

    return V

def nosketch_update_left(rowweights, A, V, lmda, t):
    """Exact (unsketched) ridge update of the left factor U with V fixed.

    For each row i, solves
        min_u ||(u V - A[i, :]) D_i||^2 + lmda * ||u||^2
    with D_i = rowweights[i].

    Args:
        rowweights: Sequence indexed by row, each a d-by-d weight matrix.
        A: The n-by-d target matrix.
        V: The k-by-d right factor.
        lmda: Ridge regularization strength.
        t: Unused; accepted so the signature matches ``update_left``.

    Returns:
        The n-by-k left factor as an ndarray.
    """
    n, d = A.shape
    k = V.shape[0]
    U = np.ones((n, k))
    # Appending sqrt(lmda) * I to the design matrix (and zeros to the target)
    # turns the ridge penalty into extra least-squares rows.
    ridge_rows = (lmda ** 0.5) * np.identity(k)
    ridge_target = np.zeros((k, 1))

    for i in range(n):
        weights = rowweights[i]
        weighted_factor = V @ weights
        weighted_row = A[[i], :] @ weights
        # Transpose so the row problem becomes a standard column problem.
        design = np.vstack((weighted_factor.T, ridge_rows))
        target = np.vstack((weighted_row.T, ridge_target))
        solution = la.lstsq(design, target, rcond=None)[0]
        U[i, :] = solution.T

    return U


def altmin(colweights, rowweights, A, t, k, lmda, num_updates, sketch):
    """Run alternating minimization from factors sliced out of A.

    The left factor starts as the first k columns of A and the right factor
    as the first k rows of A. Each round updates the left factor and then the
    right factor, using the sketched update routines when ``sketch`` is truthy
    and the exact ones otherwise.

    Args:
        colweights: Per-column weight matrices passed to the right update.
        rowweights: Per-row weight matrices passed to the left update.
        A: The target matrix.
        t: Sketch size forwarded to the update routines.
        k: Target rank (number of columns/rows taken for the initial factors).
        lmda: Ridge regularization strength.
        num_updates: Number of alternating rounds.
        sketch: Whether to use the sketched update routines.

    Returns:
        The list of ``(U, V)`` iterates produced by
        ``alternating_minimization`` (initial pair first).
    """
    initial_left = A[:, :k]
    initial_right = A[:k, :]

    if sketch:
        def refresh_left(right):
            return update_left(rowweights, A, right, lmda, t)

        def refresh_right(left):
            return update_right(colweights, A, left, lmda, t)
    else:
        def refresh_left(right):
            return nosketch_update_left(rowweights, A, right, lmda, t)

        def refresh_right(left):
            return nosketch_update_right(colweights, A, left, lmda, t)

    return alternating_minimization(initial_left, initial_right,
                                    refresh_left, refresh_right, num_updates)

def rwlra(W, A, t, k, lmda, outf, sketch):
    """Run regularized weighted low-rank approximation and report the result.

    Builds one diagonal weight matrix per column and per row of ``W``, runs
    25 rounds of alternating minimization, and prints the final objective
        ||W * (U V - A)||_F^2 + lmda * ||U||_F^2 + lmda * ||V||_F^2
    (``*`` is elementwise) together with a line saying whether sketching was
    used.

    Args:
        W: n-by-d matrix of entrywise weights.
        A: n-by-d target matrix.
        t: Sketch size forwarded to ``altmin``.
        k: Target rank forwarded to ``altmin``.
        lmda: Ridge regularization strength.
        outf: File-like object that receives the printed report.
        sketch: Whether to use the sketched update routines.

    Returns:
        None; the results are only printed to ``outf``.
    """
    n, d = A.shape
    num_updates = 25

    # Construct the diagonal weight matrices directly in sparse form;
    # np.diag would first allocate a dense square matrix for every single
    # column and row of W.
    colweights = [sparse.diags(W[:, j], format="csr") for j in range(d)]
    rowweights = [sparse.diags(W[i, :], format="csr") for i in range(n)]

    # Only the final iterate is reported.
    U, V = altmin(colweights, rowweights, A, t, k, lmda, num_updates, sketch)[-1]
    objective = (la.norm(W * (U @ V - A), "fro") ** 2
                 + lmda * la.norm(U, "fro") ** 2
                 + lmda * la.norm(V, "fro") ** 2)
    print("Altmin objective value is ", objective, file=outf)
    if sketch:
        print("Used sketching", file=outf)
    else:
        print("Did not sketch", file=outf)

def wsvd(W, A, k, lmda, outf):
    """Print the regularized weighted objective of the truncated SVD of A.

    The SVD of ``A`` is truncated to ``min(k, number of singular values)``
    components, with the singular values folded into the left factor. The
    printed value is
        ||W * (U V - A)||_F^2 + lmda * ||U||_F^2 + lmda * ||V||_F^2
    where ``*`` is elementwise.

    Args:
        W: Matrix of entrywise weights, same shape as ``A``.
        A: Target matrix.
        k: Requested rank.
        lmda: Ridge regularization strength.
        outf: File-like object that receives the printed line.

    Returns:
        None.
    """
    left_vecs, sing_vals, right_vecs = la.svd(A, full_matrices=False)
    rank = min(sing_vals.shape[0], k)

    # Absorb the singular values into the left factor before truncating.
    scaled_left = left_vecs @ np.diag(sing_vals)
    left_k = scaled_left[:, :rank]
    right_k = right_vecs[:rank, :]

    fit = la.norm(W * (left_k @ right_k - A), "fro") ** 2
    left_penalty = lmda * la.norm(left_k, "fro") ** 2
    right_penalty = lmda * la.norm(right_k, "fro") ** 2
    print("SVD objective value is ", fit + left_penalty + right_penalty, file=outf)

def gen_synthetic(r, c, rank, lmda):
    """Generate a random r-by-c matrix with a prescribed singular spectrum.

    The matrix is ``U @ diag(sing) @ V`` where ``U`` holds ``rank``
    orthonormal columns and ``V`` holds ``rank`` orthonormal rows, both taken
    from QR factorizations of Gaussian matrices. The leading singular value
    is ``r``; the remaining ``rank - 1`` are all ``sqrt(lmda / rank)``.

    Args:
        r: Number of rows.
        c: Number of columns.
        rank: Number of nonzero singular values.
        lmda: Regularization strength used to scale the trailing singular
            values.

    Returns:
        An ``(r, c)`` ndarray; the zero matrix when ``rank`` is 0.
    """
    if rank == 0:
        # There is no leading singular value to set; a rank-0 matrix is zero.
        return np.zeros((r, c))

    U = la.qr(np.random.randn(r, r))[0][:, :rank]
    V = la.qr(np.random.randn(c, c))[0][:rank, :]

    sing = np.full(rank, (lmda / rank) ** 0.5)
    sing[0] = r

    return U @ np.diag(sing) @ V

def top_singular(A):
    """Return the largest singular value of ``A``.

    Args:
        A: A 2-D array.

    Returns:
        The leading singular value (the spectral norm of ``A``).
    """
    # Only the singular values are needed, so skip computing the vectors.
    return la.svd(A, compute_uv=False)[0]

def estimate_sd(A, lmda):
    """Compute the statistical dimension of ``A`` at regularization ``lmda``.

    This is the sum over singular values ``s`` of ``A`` of
    ``1 / (1 + lmda / s**2)``, evaluated in the algebraically equivalent form
    ``s**2 / (s**2 + lmda)``.

    Args:
        A: A 2-D array.
        lmda: Ridge regularization strength.

    Returns:
        The statistical dimension as a float.
    """
    # Only the singular values are needed, so skip computing the vectors.
    squared = la.svd(A, compute_uv=False) ** 2
    # s^2 / (s^2 + lmda) avoids dividing by s^2, so zero singular values of a
    # rank-deficient A contribute 0 without a divide-by-zero warning.
    return float(np.sum(squared / (squared + lmda)))

def gen_weights(rows, cols):
    """Sample a rows-by-cols matrix of entrywise weights.

    Each entry is drawn independently from NumPy's global random state:
    1 with probability 0.85, 0.1 with probability 0.1, and 0.01 with
    probability 0.05.

    Args:
        rows: Number of rows.
        cols: Number of columns.

    Returns:
        A ``(rows, cols)`` ndarray of weights.
    """
    levels = [1, 0.1, 0.01]
    level_probs = [0.85, 0.1, 0.05]

    flat = np.random.choice(levels, rows * cols, p=level_probs)
    return flat.reshape((rows, cols))

def train(n, d, t, k, lmda, outf):
    """Run one experiment on the landmark data and log results to ``outf``.

    Loads ``landmark.mtx``, samples ``n`` of its rows without replacement,
    and uses their RBF kernel matrix as the target ``A``. It then generates
    random entrywise weights and reports, with wall-clock timings, the
    objective of the truncated SVD, of sketched alternating minimization, and
    of unsketched alternating minimization.

    Args:
        n: Number of data rows to sample (``A`` is the kernel matrix of them).
        d: Only echoed in the header line; not otherwise used here.
        t: Sketch size.
        k: Target rank.
        lmda: Ridge regularization strength.
        outf: File-like object that receives every printed line.

    Returns:
        None.
    """
    print("(n, d, t, k, lmda) = ({}, {}, {}, {}, {})".format(n, d, t, k, lmda), file=outf)

    # synthetic
    # A = gen_synthetic(n, d, d, lmda)

    # real
    dataset = io.mmread("landmark.mtx").tocsr()
    chosen_rows = np.random.choice(dataset.shape[0], n, replace=False)
    A = skl.rbf_kernel(dataset[chosen_rows])
    print("sd estimate", estimate_sd(A, lmda), file=outf)
    print("Frob norm estimate", la.norm(A, "fro"), file=outf)

    tic = time.time()
    W = gen_weights(A.shape[0], A.shape[1])
    elapsed = time.time() - tic
    print("Took", round(elapsed, 2), "seconds to generate weight matrix", file=outf)

    tic = time.time()
    wsvd(W, A, k, lmda, outf)
    elapsed = time.time() - tic
    print("SVD took", round(elapsed, 2), "seconds", file=outf)

    # NOTE(review): the sketched run passes t as the target rank (fourth
    # argument) while the unsketched run below passes k — confirm that this
    # difference is intentional.
    tic = time.time()
    rwlra(W, A, t, t, lmda, outf, True)
    elapsed = time.time() - tic
    print("Sketching took", round(elapsed, 2), "seconds", file=outf)

    tic = time.time()
    rwlra(W, A, t, k, lmda, outf, False)
    elapsed = time.time() - tic
    print("Not sketching took", round(elapsed, 2), "seconds", file=outf)

    print(".........................", file=outf)


